import numpy as np
import matplotlib.pyplot as plt
from time import time

class SVM:
    """Soft-margin kernel SVM with a Gaussian (RBF) kernel, trained by SMO.

    The decision function is ``f(x) = sum_i a_i * y_i * K(x_i, x) - b`` and
    labels are expected to be -1 / +1.

    Parameters
    ----------
    C : float
        Penalty coefficient; upper bound of every Lagrange multiplier.
    sigma : float
        Width of the Gaussian kernel.
    eps : float
        Tolerance used for the KKT check and for the relative duality gap.
    max_iter : int
        Maximum number of SMO steps performed by :meth:`fit`.

    After :meth:`fit` the instance exposes ``a`` (multipliers), ``b`` (bias),
    ``E`` (error cache ``f(x_i) - y_i``), ``K`` (training kernel matrix) and
    the support-vector views ``sv_indices``, ``sv_X``, ``sv_y``, ``sv_a``.
    """

    # A step that moves the second multiplier by less than this is a no-op.
    _MIN_STEP = 1e-10
    # Rows evaluated at once in predict(); bounds the temporary kernel matrix.
    _PREDICT_CHUNK = 4096

    def __init__(self, C=100, sigma=0.25, eps=1e-3, max_iter=100):
        self.C = C          # penalty coefficient
        self.sigma = sigma  # Gaussian kernel parameter
        self.eps = eps      # stopping tolerance
        self.max_iter = max_iter

    def fit(self, X, y):
        """Train on samples ``X`` (one row per sample) with labels ``y`` in {-1, +1}."""
        X = np.asarray(X)
        # Cast labels to float: with integer labels the error cache would be
        # an integer array and every later error update would be truncated.
        y = np.asarray(y, dtype=float).ravel()
        self.X = X
        self.y = y
        self.n_samples = X.shape[0]

        # Initial state: all multipliers zero, hence f(x) = 0 and E = -y.
        self.a = np.zeros(self.n_samples)
        self.b = 0.0
        self.E = -y.copy()

        # Pre-compute the full training kernel matrix once.
        self.K = self._kernel_matrix(X, X)
        np.fill_diagonal(self.K, 1.0)  # K(x, x) = exp(0) exactly

        # SMO main loop.
        for step in range(self.max_iter):
            i1, i2 = self.select_alpha_pair()
            if i1 is None:
                break  # every sample satisfies the KKT conditions

            # If the heuristic pair cannot move, the state is unchanged and the
            # same pair would be chosen forever; try other pairs, and stop
            # when no pair at all can make progress.
            if not self.optimize_alpha_pair(i1, i2) and not self._fallback_step():
                break

            # Periodic early-stopping check.
            if (step + 1) % 10 == 0 and self.check_stopping_condition():
                break

        self.extract_support_vectors()
        print(f'找到 {len(self.sv_indices)} 个支持向量')

    def kernel(self, x1, x2):
        """Return the Gaussian kernel value between two single samples."""
        return np.exp(-np.sum((x1-x2)**2) / (2*self.sigma**2))

    def _kernel_matrix(self, A, B):
        """Return the matrix of Gaussian kernel values between rows of A and B."""
        A = np.asarray(A, dtype=float)
        B = np.asarray(B, dtype=float)
        if A.ndim == 1:
            A = A[:, None]
        if B.ndim == 1:
            B = B[:, None]
        sq_dist = (
            np.sum(A * A, axis=1)[:, None]
            + np.sum(B * B, axis=1)[None, :]
            - 2.0 * (A @ B.T)
        )
        # Rounding can make tiny distances slightly negative.
        np.maximum(sq_dist, 0.0, out=sq_dist)
        return np.exp(-sq_dist / (2 * self.sigma ** 2))

    def _kkt_violations(self):
        """Return ``(|y*E|, mask)`` where mask flags samples violating KKT."""
        margin_err = self.y * self.E
        violating = ((self.a < self.C) & (margin_err < -self.eps)) | (
            (self.a > 0) & (margin_err > self.eps)
        )
        return np.abs(margin_err), violating

    def select_alpha_pair(self):
        """Pick the pair of multipliers to optimise next.

        The first index is the sample with the largest KKT violation; the
        second maximises ``|E1 - E2|``. Returns ``(None, None)`` when no
        sample violates the KKT conditions.
        """
        magnitude, violating = self._kkt_violations()
        if not violating.any():
            return None, None

        i1 = int(np.argmax(np.where(violating, magnitude, -np.inf)))
        i2 = int(np.argmax(np.abs(self.E - self.E[i1])))
        if i2 == i1:
            i2 = (i1 + 1) % self.n_samples
        return i1, i2

    def optimize_alpha_pair(self, i1, i2):
        """Jointly optimise multipliers ``i1`` and ``i2``.

        Returns True when the multipliers were changed, False when the pair
        cannot make progress (in which case no state is modified).
        """
        if i1 == i2:
            return False

        y1, y2 = self.y[i1], self.y[i2]
        a1_old, a2_old = self.a[i1], self.a[i2]
        E1, E2 = self.E[i1], self.E[i2]
        K11 = self.K[i1, i1]
        K12 = self.K[i1, i2]
        K22 = self.K[i2, i2]

        # Feasible segment for a2 given the box and equality constraints.
        if y1 != y2:
            low = max(0, a2_old - a1_old)
            high = min(self.C, self.C + a2_old - a1_old)
        else:
            low = max(0, a2_old + a1_old - self.C)
            high = min(self.C, a2_old + a1_old)
        if low >= high:
            return False

        # Curvature of the objective along the constraint line.
        eta = K11 + K22 - 2*K12
        if eta <= 0:
            return False

        a2_new = float(np.clip(a2_old + y2*(E1 - E2)/eta, low, high))
        if abs(a2_new - a2_old) < self._MIN_STEP:
            # Clipped back onto the old value: leave b and E untouched.
            return False

        a1_new = a1_old + y1*y2*(a2_old - a2_new)
        # Remove rounding spill outside the box [0, C].
        a1_new = min(max(a1_new, 0.0), self.C)

        # Bias candidates that zero the error of sample 1 / sample 2.
        b1 = E1 + y1*(a1_new - a1_old)*K11 + y2*(a2_new - a2_old)*K12 + self.b
        b2 = E2 + y1*(a1_new - a1_old)*K12 + y2*(a2_new - a2_old)*K22 + self.b
        if 0 < a1_new < self.C:
            new_b = b1
        elif 0 < a2_new < self.C:
            new_b = b2
        else:
            new_b = (b1 + b2)/2

        self.a[i1] = a1_new
        self.a[i2] = a2_new
        self.b = new_b

        # Recompute the whole error cache from scratch (no drift).
        self.E = (self.a * self.y) @ self.K - self.b - self.y
        return True

    def _fallback_step(self):
        """Try violators (worst first) with every partner until one pair moves."""
        magnitude, violating = self._kkt_violations()
        candidates = np.flatnonzero(violating)
        candidates = candidates[np.argsort(-magnitude[candidates], kind="stable")]
        for i1 in candidates:
            partners = np.argsort(-np.abs(self.E - self.E[i1]), kind="stable")
            for i2 in partners:
                if self.optimize_alpha_pair(int(i1), int(i2)):
                    return True
        return False

    def _fx(self, i):
        """Return the decision value of training sample ``i``."""
        return np.dot(self.a * self.y, self.K[:, i]) - self.b

    def check_stopping_condition(self):
        """Return True when the relative primal-dual gap is within ``eps``.

        dual   = sum(a) - 0.5 * ||w||^2
        primal = 0.5 * ||w||^2 + C * sum(max(0, 1 - y_i * f(x_i)))
        with ``||w||^2 = (a*y)^T K (a*y)`` and ``f(x_i) = E_i + y_i``.
        """
        coef = self.a * self.y
        w_norm_sq = coef @ self.K @ coef
        dual = np.sum(self.a) - 0.5 * w_norm_sq
        slack = np.maximum(0.0, 1.0 - self.y * (self.E + self.y))
        primal = 0.5 * w_norm_sq + self.C * np.sum(slack)
        return bool(primal - dual <= self.eps * (abs(dual) + 1.0))

    def extract_support_vectors(self):
        """Cache the samples whose multiplier is non-negligible (> 1e-5)."""
        self.sv_indices = np.where(self.a > 1e-5)[0]
        self.sv_X = self.X[self.sv_indices]
        self.sv_y = self.y[self.sv_indices]
        self.sv_a = self.a[self.sv_indices]

    def predict(self, X):
        """Return ``sign(f(x))`` for every row of ``X`` as a float array."""
        X = np.asarray(X, dtype=float)
        n = X.shape[0]
        coef = self.sv_a * self.sv_y
        y_pred = np.zeros(n)
        # Evaluate in chunks so the temporary kernel matrix stays small.
        for start in range(0, n, self._PREDICT_CHUNK):
            stop = start + self._PREDICT_CHUNK
            k = self._kernel_matrix(X[start:stop], self.sv_X)
            y_pred[start:stop] = np.sign(k @ coef - self.b)
        return y_pred

# Data generation
def generate_moon_data(n_train=300, n_val=2000):
    """Sample a two-class "double moon" data set.

    Each point is assigned, with a fair coin flip, to either the upper half
    ring centred at the origin (label -1) or the lower half ring shifted to
    ``(10, 6.5)`` (label +1). The ring has mean radius 10 and width 6.

    Randomness comes from the global ``np.random`` state; three uniform
    draws are consumed per point (class, angle, radius).

    Returns
    -------
    tuple
        ``(X_train, y_train, X_val, y_val)`` where the first ``n_train``
        generated points form the training split and the rest validation.
    """
    radius, width, offset = 10, 6, -6.5
    n_total = n_train + n_val
    points = np.zeros((n_total, 2))
    labels = np.zeros(n_total)

    for idx in range(n_total):
        upper = np.random.rand() > 0.5
        # Upper moon sweeps angles in [0, pi), lower moon in (-pi, 0].
        angle = np.random.rand() * (np.pi if upper else -np.pi)
        rho = radius + width * (np.random.rand() - 0.5)
        if upper:
            points[idx] = [rho * np.cos(angle), rho * np.sin(angle)]
            labels[idx] = -1
        else:
            points[idx] = [rho * np.cos(angle) + radius, rho * np.sin(angle) - offset]
            labels[idx] = 1

    return points[:n_train], labels[:n_train], points[n_train:], labels[n_train:]

# Data preprocessing
def normalize(X_train, X_val):
    """Centre and scale both splits with shared per-feature statistics.

    The mean and the scale are computed over the union of ``X_train`` and
    ``X_val`` so that both splits are transformed identically. The scale of
    a feature is the largest absolute value of the *uncentred* data.

    A feature that is zero everywhere has scale 0; it is divided by 1
    instead so the result is 0 rather than NaN.

    Returns
    -------
    tuple
        ``(X_train_norm, X_val_norm)``.
    """
    stacked = np.vstack([X_train, X_val])
    miu = np.mean(stacked, axis=0)
    max_val = np.max(np.abs(stacked), axis=0)
    # Guard against 0/0 for all-zero feature columns.
    scale = np.where(max_val == 0, 1.0, max_val)

    X_train_norm = (X_train - miu) / scale
    X_val_norm = (X_val - miu) / scale
    return X_train_norm, X_val_norm

# Visualization
def visualize(model, X, y):
    """Plot the labelled samples, the support vectors and the decision boundary.

    ``model`` must already be fitted (it is read through ``sv_indices``,
    ``sv_X`` and ``predict``). ``X`` holds 2-D points and ``y`` their -1/+1
    labels. The support vectors and the boundary are drawn only when the
    model has at least one support vector. Opens a Matplotlib window.
    """
    plt.figure(figsize=(10, 8))

    # Scatter each class in its own colour.
    class_styles = ((1, 'b', 'Class 1'), (-1, 'g', 'Class -1'))
    for label_value, colour, legend in class_styles:
        members = y == label_value
        plt.scatter(X[members, 0], X[members, 1], c=colour, label=legend, alpha=0.5)

    has_support_vectors = len(model.sv_indices) > 0
    if has_support_vectors:
        # Ring the support vectors.
        sv = model.sv_X
        plt.scatter(sv[:, 0], sv[:, 1], s=80,
                    facecolors='none', edgecolors='r', linewidths=2, label='Support Vectors')

        # Evaluate the classifier on a 200x200 grid slightly larger than the data.
        margin = 0.1
        x_axis = np.linspace(X[:, 0].min() - margin, X[:, 0].max() + margin, 200)
        y_axis = np.linspace(X[:, 1].min() - margin, X[:, 1].max() + margin, 200)
        xx, yy = np.meshgrid(x_axis, y_axis)
        grid_points = np.c_[xx.ravel(), yy.ravel()]
        Z = model.predict(grid_points).reshape(xx.shape)

        # The zero level of the predicted sign is the decision boundary.
        plt.contour(xx, yy, Z, colors='r', levels=[0], linewidths=2)

    plt.title('SVM Classification Result')
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.legend()
    plt.show()

if __name__ == "__main__":
    # Sample the double-moon data and scale both splits with shared statistics.
    X_train, y_train, X_val, y_val = generate_moon_data(300, 2000)
    X_train, X_val = normalize(X_train, X_val)

    # Train the classifier; the timer covers training and validation.
    start = time()
    svm = SVM(max_iter=200, sigma=0.2, C=10)
    svm.fit(X_train, y_train)

    # Fraction of validation samples whose predicted label is wrong.
    y_pred = svm.predict(X_val)
    misclassified = y_pred != y_val
    error_rate = misclassified.mean()

    print(f'运行时间: {time()-start:.2f}秒')
    print(f'验证集错误率: {error_rate*100:.2f}%')

    # Show the training data together with the learned boundary.
    visualize(svm, X_train, y_train)